# Import libraries and custom functions
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import itertools
import gtsam
import gtsam.utils.plot
import utils
# Define plot image size
plt.rcParams['figure.figsize'] = (20, 12)
# Load images and convert to gray images.
# CLAHE (contrast-limited adaptive histogram equalization) boosts local
# contrast so the SIFT detector finds more stable keypoints.
image_dir = 'data/test_buddha_images/'
image_list = []      # grayscale, CLAHE-equalized images (used for feature detection)
bgr_image_list = []  # original BGR frames as int arrays (int so colors can be summed later)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
# NOTE(review): reverse=True visits the images in descending filename order —
# confirm this matches the intended capture/traversal order.
for image_file_name in sorted(os.listdir(image_dir), reverse=True):
    image = cv2.imread(image_dir+image_file_name)
    bgr_image_list.append(image.astype(int))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image = clahe.apply(image)
    image_list.append(image)
# All frames are assumed to share the first image's dimensions.
image_height, image_width = image_list[0].shape
# Assume camera intrinsic parameters: focal length ~= image width,
# principal point at the image center, zero skew.
camera_matrix = np.array([[image_width, 0, image_width/2],
                          [0, image_width, image_height/2],
                          [0, 0, 1]])
# Evenly distribute the features across the image
def get_non_max_suppression_mask(keypoints, radius=3):
    """Greedy non-maximum suppression over keypoint locations.

    Keypoints are visited in decreasing detector response; one is kept only
    if no stronger keypoint has already claimed a disc of ``radius`` pixels
    around its rounded location.  This spreads the surviving features more
    evenly across the image.

    Args:
        keypoints: sequence of cv2.KeyPoint (only ``.response`` and ``.pt``
            are used).
        radius: suppression radius in pixels.  Defaults to 3, the value the
            original implementation hard-coded.

    Returns:
        List of indices into ``keypoints`` of the surviving keypoints,
        ordered by decreasing response.
    """
    # Occupancy map over the image: non-zero pixels are already claimed
    # by a stronger keypoint.  Relies on module globals image_height/width.
    binary_image = np.zeros((image_height, image_width))
    response_list = np.array([keypoint.response for keypoint in keypoints])
    # Keypoint indices sorted by decreasing response (strongest first).
    mask = np.flip(np.argsort(response_list))
    point_list = np.rint([keypoint.pt for keypoint in keypoints])[
        mask].astype(int)
    non_max_suppression_mask = []
    for point, index in zip(point_list, mask):
        # point is (x, y); the occupancy map is indexed [row, col].
        if binary_image[point[1], point[0]] == 0:
            non_max_suppression_mask.append(index)
            # Claim a disc around the accepted keypoint so weaker
            # neighbors are suppressed.
            cv2.circle(binary_image, (point[0], point[1]), radius, 255, -1)
    return non_max_suppression_mask
# Detect keypoints and create descriptors of each images
sift = cv2.SIFT_create(nOctaveLayers=6)
keypoints_list = []     # per image: keypoints surviving non-max suppression
descriptors_list = []   # per image: descriptors aligned with keypoints_list
object_index_list = []  # per image: keypoint index -> 3D landmark index (-1 = none yet)
for image in image_list:
    keypoints, descriptors = sift.detectAndCompute(image, None)
    non_max_suppression_mask = get_non_max_suppression_mask(keypoints)
    keypoints_list.append(np.array(keypoints)[non_max_suppression_mask])
    descriptors_list.append(np.array(descriptors)[non_max_suppression_mask])
    # Fix: size the landmark-index array to the *suppressed* keypoint count.
    # All later lookups use match indices into the filtered lists, so the
    # original len(keypoints) allocation was oversized (harmless but wrong).
    object_index_list.append(np.full(len(non_max_suppression_mask), -1, int))
# Match the keypoints across 2 images
bf_matcher = cv2.BFMatcher(cv2.NORM_L2)


def get_match_points(src_keypoints, src_descriptors, dst_keypoints, dst_descriptors):
    """Match descriptors between two images using Lowe's ratio test.

    Args:
        src_keypoints, dst_keypoints: cv2.KeyPoint sequences (``.pt`` used).
        src_descriptors, dst_descriptors: descriptor arrays aligned with the
            keypoint sequences.

    Returns:
        Tuple of numpy arrays ``(src_points, dst_points, src_indices,
        dst_indices)`` — matched 2D coordinates plus the keypoint indices
        they came from, one entry per accepted match.
    """
    # Two nearest neighbors per source descriptor, needed for the ratio test.
    matches = bf_matcher.knnMatch(src_descriptors, dst_descriptors, k=2)
    # Note: the original code also ran a full reverse-direction match() and
    # built a cross_match_dict that was never read — dead code, removed.
    src_points = []
    dst_points = []
    src_point_index_list = []
    dst_point_index_list = []
    for match_pair in matches:
        # Fix: knnMatch can return fewer than 2 neighbors when the
        # destination descriptor set is very small; skip instead of
        # crashing on tuple unpacking.
        if len(match_pair) < 2:
            continue
        match_1, match_2 = match_pair
        # Lowe's ratio test: keep only matches clearly better than the
        # second-best candidate.
        if match_1.distance < 0.75*match_2.distance:
            src_points.append(src_keypoints[match_1.queryIdx].pt)
            dst_points.append(dst_keypoints[match_1.trainIdx].pt)
            src_point_index_list.append(match_1.queryIdx)
            dst_point_index_list.append(match_1.trainIdx)
    return np.array(src_points), np.array(dst_points), np.array(src_point_index_list), np.array(dst_point_index_list)
# Recover pose from matching points
def get_pose_index_mask(dst_points, src_points):
    """Estimate the relative camera pose from matched 2D points.

    Runs RANSAC essential-matrix estimation followed by cv2.recoverPose,
    and returns the rotation, translation (unit scale), and the indices of
    the matches that survived both the epipolar and cheirality checks.
    """
    essential_matrix, essential_mask = cv2.findEssentialMat(
        dst_points, src_points, camera_matrix, cv2.RANSAC, 0.999, 1.0)
    # Matches consistent with the RANSAC essential matrix.
    epipolar_inliers = np.squeeze(essential_mask != 0)
    surviving_indices = np.arange(len(src_points))[epipolar_inliers]
    # Disambiguate the four pose candidates; pose_mask flags the points
    # that triangulate in front of both cameras.
    _, rotation_matrix, translation, pose_mask = cv2.recoverPose(
        essential_matrix,
        dst_points[epipolar_inliers],
        src_points[epipolar_inliers],
        camera_matrix)
    cheirality_ok = np.squeeze(pose_mask != 0)
    return rotation_matrix, translation, surviving_indices[cheirality_ok]
# Initialize gtsam nonlinear factor graph
symbol_X = gtsam.symbol_shorthand.X  # camera pose variables
symbol_L = gtsam.symbol_shorthand.L  # 3D landmark variables
# Calibration built from the assumed intrinsics: fx, fy, skew=0, cx, cy.
gtsam_camera_matrix = gtsam.Cal3_S2(
    camera_matrix[0, 0], camera_matrix[1, 1], 0.0, camera_matrix[0, 2], camera_matrix[1, 2])
# Isotropic 1-pixel noise on every 2D reprojection measurement.
measurement_noise = gtsam.noiseModel.Isotropic.Sigma(2, 1.0)
# Prior noise on the anchored first pose — presumably 0.3 rad on the
# rotation components and 0.1 on translation (gtsam Pose3 ordering);
# TODO confirm against the gtsam docs.
pose_noise = gtsam.noiseModel.Diagonal.Sigmas(
    np.array([0.3, 0.3, 0.3, 0.1, 0.1, 0.1]))
# Prior noise for the single landmark prior that fixes the gauge/scale.
point_noise = gtsam.noiseModel.Isotropic.Sigma(3, 0.1)
graph = gtsam.NonlinearFactorGraph()
initial = gtsam.Values()  # initial estimates for all poses and landmarks
# Match the features of all consecutive images
# Bootstrap: match image 0 -> image 1 and recover their relative pose.
src_points, dst_points, src_point_index_list, dst_point_index_list = get_match_points(
    keypoints_list[0], descriptors_list[0], keypoints_list[1], descriptors_list[1])
rotation_matrix, translation, index_mask = get_pose_index_mask(
    dst_points, src_points)
# Keep only the matches that survived the RANSAC + cheirality checks.
src_points = src_points[index_mask]
dst_points = dst_points[index_mask]
src_point_index_list = src_point_index_list[index_mask]
dst_point_index_list = dst_point_index_list[index_mask]
# utils.plot_match_points(image_list[0], image_list[1], src_points, dst_points)
# Triangulate points from 2 views
def get_object_points(src_projection_matrix, dst_projection_matrix, src_points, dst_points):
    """Triangulate matched 2D points from two views into 3D coordinates.

    Takes the 3x4 projection matrices of both views and (N, 2) point
    arrays; returns an (N, 3) array of Euclidean 3D points.
    """
    homogeneous = cv2.triangulatePoints(
        src_projection_matrix, dst_projection_matrix, src_points.T, dst_points.T)
    # Dehomogenize: divide x, y, z rows by w, then one point per row.
    return (homogeneous[:3] / homogeneous[3]).T
# get transformation matrix from rotation and translation
def get_transformation_matrix(rotation_matrix, translation):
    """Pack a 3x3 rotation and a 3x1 translation into a 4x4 homogeneous matrix.

    The bottom row stays [0, 0, 0, 1] from the identity initialization.
    """
    transform = np.eye(4)
    transform[:3, :3] = rotation_matrix
    transform[:3, 3:] = translation
    return transform
# get projection matrix from intrinsic and extrinsic parameters
def get_projection_matrix(transformation_matrix):
    """Build the 3x4 world-to-image projection for a camera pose.

    ``transformation_matrix`` maps camera to world, so its inverse is the
    extrinsics; multiplying by the global ``camera_matrix`` intrinsics gives
    the projection.
    """
    world_to_camera = np.linalg.inv(transformation_matrix)
    return camera_matrix @ world_to_camera[:3]
# Push all the matching points into gtsam nonlinear factor graph
transformation_matrix = get_transformation_matrix(rotation_matrix, translation)
# Anchor the gauge: fix pose 0 at the origin with a prior factor.
pose_factor = gtsam.PriorFactorPose3(symbol_X(0), gtsam.Pose3(), pose_noise)
graph.push_back(pose_factor)
initial.insert(symbol_X(0), gtsam.Pose3())
initial.insert(symbol_X(1),
               gtsam.Pose3(gtsam.Rot3(rotation_matrix),
                           gtsam.Point3(translation.flatten())))
projection_matrix = get_projection_matrix(transformation_matrix)
# Triangulate the bootstrap pair; camera 0 projects with K[I|0].
# NOTE(review): entries are running *sums* of triangulated positions per
# landmark; they are divided by all_object_points_count later.
sum_all_object_points = list(get_object_points(
    camera_matrix.dot(np.hstack((np.eye(3), np.zeros((3, 1))))), projection_matrix, src_points, dst_points))
prev_transformation_matrix = transformation_matrix.copy()
prev_projection_matrix = projection_matrix.copy()
# A prior on the first landmark pins the overall reconstruction scale.
point_factor = gtsam.PriorFactorPoint3(
    symbol_L(0), sum_all_object_points[0], point_noise)
graph.push_back(point_factor)
sum_all_object_points_color = []  # running BGR color sums per landmark
all_object_points_count = []      # observation count per landmark
src_object_index_array = object_index_list[0]
dst_object_index_array = object_index_list[1]
src_bgr_image = bgr_image_list[0]
dst_bgr_image = bgr_image_list[1]
# Rounded-integer pixel coordinates for color sampling.
rint_src_points = utils.get_rint(src_points)
rint_dst_points = utils.get_rint(dst_points)
for object_index, (src_point, dst_point, src_point_index, dst_point_index, object_point, rint_src_point, rint_dst_point) in enumerate(zip(src_points,
        dst_points,
        src_point_index_list,
        dst_point_index_list,
        sum_all_object_points,
        rint_src_points,
        rint_dst_points)):
    # Record which landmark each 2D keypoint observes.
    src_object_index_array[src_point_index] = object_index
    dst_object_index_array[dst_point_index] = object_index
    # One reprojection factor per (camera, landmark) observation.
    graph.push_back(gtsam.GenericProjectionFactorCal3_S2(
        src_point, measurement_noise, symbol_X(0), symbol_L(object_index), gtsam_camera_matrix))
    graph.push_back(gtsam.GenericProjectionFactorCal3_S2(
        dst_point, measurement_noise, symbol_X(1), symbol_L(object_index), gtsam_camera_matrix))
    # Sum the BGR samples from both views; averaged at the very end.
    sum_all_object_points_color.append(
        src_bgr_image[rint_src_point[1], rint_src_point[0]]+dst_bgr_image[rint_dst_point[1], rint_dst_point[0]])
    all_object_points_count.append(2)
# Triangulate all the matching points and push into gtsam nonlinear factor graph
for src_index in range(1, len(image_list)-1):
    dst_index = src_index+1
    # Match the next consecutive pair and keep pose-consistent inliers only.
    src_points, dst_points, src_point_index_list, dst_point_index_list = get_match_points(
        keypoints_list[src_index], descriptors_list[src_index], keypoints_list[dst_index], descriptors_list[dst_index])
    rotation_matrix, translation, index_mask = get_pose_index_mask(
        dst_points, src_points)
    src_points = src_points[index_mask]
    dst_points = dst_points[index_mask]
    src_point_index_list = src_point_index_list[index_mask]
    dst_point_index_list = dst_point_index_list[index_mask]
    utils.plot_match_points(
        image_list[src_index], image_list[dst_index], src_points, dst_points)
    # Chain the (unit-scale) relative pose onto the previous absolute pose.
    transformation_matrix = get_transformation_matrix(
        rotation_matrix, translation).dot(prev_transformation_matrix)
    projection_matrix = get_projection_matrix(transformation_matrix)
    match_src_points = []
    match_dst_points = []
    match_object_points = []
    src_object_index_array = object_index_list[src_index]
    dst_object_index_array = object_index_list[dst_index]
    src_bgr_image = bgr_image_list[src_index]
    dst_bgr_image = bgr_image_list[dst_index]
    # Collect the matches whose source keypoint already has a triangulated
    # landmark; these are used below to resolve the translation scale.
    for src_point, dst_point, src_point_index in zip(src_points, dst_points, src_point_index_list):
        object_index = src_object_index_array[src_point_index]
        if object_index != -1:
            match_src_points.append(src_point)
            match_dst_points.append(dst_point)
            # NOTE(review): the running sum is divided by (count - 1),
            # matching the final averaging at the end of the script —
            # confirm this offset is intentional.
            match_object_points.append(
                sum_all_object_points[object_index]/(all_object_points_count[object_index]-1))
    match_src_points = np.array(match_src_points)
    match_dst_points = np.array(match_dst_points)
    match_object_points = np.array(match_object_points)
    # Re-triangulate the known landmarks with the unit-scale pose and
    # estimate the scale as the average ratio of distances from the origin.
    object_points = get_object_points(
        prev_projection_matrix, projection_matrix, match_src_points, match_dst_points)
    scale = 0
    for match_object_point, object_point in zip(match_object_points, object_points):
        scale += cv2.norm(match_object_point)/cv2.norm(object_point)
    scale /= len(object_points)
    translation *= scale
    # Rebuild the absolute pose with the scale-corrected translation.
    transformation_matrix = get_transformation_matrix(
        rotation_matrix, translation).dot(prev_transformation_matrix)
    initial.insert(symbol_X(dst_index),
                   gtsam.Pose3(gtsam.Rot3(transformation_matrix[:3, :3]),
                               gtsam.Point3(transformation_matrix[:3, 3])))
    projection_matrix = get_projection_matrix(transformation_matrix)
    # Triangulate *all* inlier matches of this pair at the corrected scale.
    object_points = get_object_points(
        prev_projection_matrix, projection_matrix, src_points, dst_points)
    prev_transformation_matrix = transformation_matrix.copy()
    prev_projection_matrix = projection_matrix.copy()
    rint_src_points = utils.get_rint(src_points)
    rint_dst_points = utils.get_rint(dst_points)
    current_object_index = len(all_object_points_count)
    for src_point, dst_point, src_point_index, dst_point_index, object_point, rint_src_point, rint_dst_point in zip(src_points,
            dst_points,
            src_point_index_list,
            dst_point_index_list,
            object_points,
            rint_src_points,
            rint_dst_points):
        object_index = src_object_index_array[src_point_index]
        if object_index == -1:
            # New landmark: start its position/color sums and label both
            # observing keypoints with its index.
            object_index = current_object_index
            sum_all_object_points.append(object_point)
            sum_all_object_points_color.append(
                src_bgr_image[rint_src_point[1], rint_src_point[0]] + dst_bgr_image[rint_dst_point[1], rint_dst_point[0]])
            all_object_points_count.append(2)
            src_object_index_array[src_point_index] = dst_object_index_array[
                dst_point_index] = object_index
            current_object_index += 1
        else:
            # Known landmark: extend its track into the new image and
            # accumulate the fresh triangulation and color sample.
            dst_object_index_array[dst_point_index] = object_index
            sum_all_object_points[object_index] += object_point
            sum_all_object_points_color[object_index] += dst_bgr_image[rint_dst_point[1],
                                                                       rint_dst_point[0]]
            all_object_points_count[object_index] += 1
        # Reprojection factors for both observations of this landmark.
        graph.push_back(gtsam.GenericProjectionFactorCal3_S2(
            src_point, measurement_noise, symbol_X(src_index), symbol_L(object_index), gtsam_camera_matrix))
        graph.push_back(gtsam.GenericProjectionFactorCal3_S2(
            dst_point, measurement_noise, symbol_X(dst_index), symbol_L(object_index), gtsam_camera_matrix))
# Run global optimization on gtsam nonlinear factor graph
# Replicate the counts into 3 columns so they divide the (N, 3) sums
# elementwise.
all_object_points_count = np.array(
    [all_object_points_count, all_object_points_count, all_object_points_count]).T
# NOTE(review): positions divide by (count - 1) while colors divide by
# count — confirm this offset is intentional.
all_object_points = np.array(sum_all_object_points)/(all_object_points_count-1)
# Fix: np.flip with no axis reverses BOTH axes, which reversed the *row
# order* of the colors so they no longer lined up with their landmarks.
# axis=1 flips only the channels (BGR -> RGB for matplotlib) and keeps
# the per-landmark row order intact.
all_object_points_color = np.flip(np.array(
    sum_all_object_points_color)/all_object_points_count, axis=1).astype(np.uint8)
# Seed the landmark initial estimates with the averaged triangulations.
for object_index, object_point in enumerate(all_object_points):
    initial.insert(symbol_L(object_index),
                   gtsam.Point3(object_point))
params = gtsam.LevenbergMarquardtParams()
optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)
result = optimizer.optimize()
# Visulation
%matplotlib notebook
fig = plt.figure()
ax = fig.gca(projection='3d')
final_object_points = []
for index in range(len(all_object_points)):
final_object_points.append(result.atPoint3(symbol_L(index)))
final_object_points = np.array(final_object_points)
ax.scatter(final_object_points[:, 0],
final_object_points[:, 1], final_object_points[:, 2], c=all_object_points_color/255., s=3)
gtsam.utils.plot.plot_trajectory(1, result, scale=2)
gtsam.utils.plot.set_axes_equal(1)
ax.set_xlim3d(-40, 20)
ax.set_ylim3d(-30, 20)
ax.set_zlim3d(-10, 40)
plt.show()